{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "orig_nbformat": 4,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.6 64-bit ('base': conda)"
  },
  "interpreter": {
   "hash": "7b4c34fa5edc2b5c200e84280da452af41185f443fff3e767a73b82cf30c2550"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 25,
   "source": [
    "from pathlib import Path\n",
    "\n",
    "from common.configs.path import paths\n",
    "from common.configs.stopwords import punc, stopwords\n",
    "from common.configs.tools import label_map\n",
    "\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import re\n",
    "\n",
    "# from gensim.test.utils import common_texts\n",
    "# from gensim.models.doc2vec import Doc2Vec, TaggedDocument\n",
    "\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "import pickle\n",
    "\n",
    "import nltk\n",
    "\n",
    "def save_model(model, path):\n",
    "    with open(path, 'wb') as f:\n",
    "        pickle.dump(model, f)    \n",
    "\n",
    "def load_model(path):\n",
    "    with open(path, 'rb') as f:\n",
    "        model = pickle.load(f)\n",
    "    return model\n",
    "\n",
    "def load_npy(path):\n",
    "    return np.load(path)\n",
    "\n",
    "con_tfidf = load_npy(paths['output'] / 'tfidf/tfidf_con_3000.npy')\n"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "source": [
    "train_df = pd.read_csv(paths['train_data'])\n",
    "test_df = pd.read_csv(paths['test_data'])\n",
    "\n",
    "# Escape each stopword so regex metacharacters (e.g. '.', '?', '(') are matched\n",
    "# literally, and compile the alternation once instead of re-joining it per row.\n",
    "stopword_pattern = re.compile('|'.join(re.escape(word) for word in stopwords))\n",
    "\n",
    "# Train rows come first, then test rows; the doc2vec cell below relies on this order.\n",
    "text_ = pd.concat([train_df['text'], test_df['text']]).apply(lambda e: stopword_pattern.sub('', e))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "source": [
    "train_df.shape"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(14009, 3)"
      ]
     },
     "metadata": {},
     "execution_count": 27
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "source": [
    "# Tokenize every document; text_ becomes a list of token lists (same order as before).\n",
    "text_ = list(map(nltk.word_tokenize, text_))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "source": [
    "# gensim supplies Doc2Vec/TaggedDocument; the import in the first cell is commented out,\n",
    "# which is what raised the NameError recorded for this cell.\n",
    "from gensim.models.doc2vec import Doc2Vec, TaggedDocument\n",
    "\n",
    "# text_ holds the train documents first, then the test documents, so the boundary is\n",
    "# the number of train rows rather than a hard-coded 14009.\n",
    "n_train = len(train_df)\n",
    "\n",
    "# NOTE(review): every train document shares the tag 'train' and every test document\n",
    "# shares 'test', so only two document vectors are learned - confirm that is intended.\n",
    "docs = [\n",
    "    TaggedDocument(tokens, ['train' if i < n_train else 'test'])\n",
    "    for i, tokens in enumerate(text_)\n",
    "]\n",
    "\n",
    "d2v_model = Doc2Vec(docs, vector_size=300, window=300, min_count=1, workers=4)\n",
    "\n",
    "# Learned vectors are looked up by tag, e.g. d2v_model.docvecs['train']\n",
    "# (d2v_model.dv['train'] on gensim >= 4.0).\n"
   ],
   "outputs": [
    {
     "output_type": "error",
     "ename": "NameError",
     "evalue": "name 'TaggedDocument' is not defined",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-29-19f1197ab1b8>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      5\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mi\u001b[0m \u001b[1;33m>=\u001b[0m \u001b[1;36m14009\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      6\u001b[0m         \u001b[0mtags\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m'test'\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 7\u001b[1;33m     \u001b[0mdocs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mTaggedDocument\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mwords\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mtags\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      8\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      9\u001b[0m \u001b[0md2v_model\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdoc2vec\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mDoc2Vec\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdocs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mvector_size\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m300\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mwindow\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m300\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmin_count\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mworkers\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m4\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'TaggedDocument' is not defined"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "import torch.nn as nn\n",
    "import torch"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "source": [
    "m = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=3, stride=2)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "source": [
    "x = torch.randn(3, 4, 2)\n",
    "x"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[[ 1.8414, -0.1491],\n",
       "         [-0.1021,  1.3026],\n",
       "         [-3.0290,  0.2963],\n",
       "         [ 0.0821, -0.7059]],\n",
       "\n",
       "        [[ 0.9560, -0.3225],\n",
       "         [-1.2947,  0.1147],\n",
       "         [ 1.3708,  1.3979],\n",
       "         [ 1.1273, -0.9528]],\n",
       "\n",
       "        [[-0.7996,  0.6791],\n",
       "         [ 0.9250, -0.9531],\n",
       "         [-0.6307,  1.0215],\n",
       "         [ 0.2368,  1.8031]]])"
      ]
     },
     "metadata": {},
     "execution_count": 20
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "source": [
    "# x is treated as (samples, sequence positions, channels); read all three sizes from\n",
    "# it instead of hard-coding the channel count. Raises ValueError if x is not 3-D.\n",
    "sample_size, sequence_length, channels = x.shape"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "source": [
    "# Swap the sequence and channel axes. reshape() would only reinterpret the flat buffer\n",
    "# and mix values from different sequence positions into one channel.\n",
    "x_reshape = x.transpose(1, 2).contiguous()\n",
    "assert x_reshape.shape == (sample_size, channels, sequence_length)\n",
    "x_reshape.shape"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "torch.Size([3, 2, 4])"
      ]
     },
     "metadata": {},
     "execution_count": 22
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "source": [
    "m(x_reshape).shape"
   ],
   "outputs": [
    {
     "output_type": "error",
     "ename": "RuntimeError",
     "evalue": "Expected 4-dimensional input for 4-dimensional weight 1 2 3, but got 3-dimensional input of size [3, 2, 4] instead",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-23-9d218da4ad68>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_reshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    545\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    546\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 547\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    548\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    549\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    341\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    342\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 343\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2d_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    344\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    345\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mConv3d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_ConvNd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36mconv2d_forward\u001b[0;34m(self, input, weight)\u001b[0m\n\u001b[1;32m    338\u001b[0m                             _pair(0), self.dilation, self.groups)\n\u001b[1;32m    339\u001b[0m         return F.conv2d(input, weight, self.bias, self.stride,\n\u001b[0;32m--> 340\u001b[0;31m                         self.padding, self.dilation, self.groups)\n\u001b[0m\u001b[1;32m    341\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    342\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Expected 4-dimensional input for 4-dimensional weight 1 2 3, but got 3-dimensional input of size [3, 2, 4] instead"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "source": [
    "mpool = nn.MaxPool1d(2, )"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "source": [
    "mpool(m(x_reshape))"
   ],
   "outputs": [
    {
     "output_type": "error",
     "ename": "RuntimeError",
     "evalue": "Given input size: (1x1x1). Calculated output size: (1x1x0). Output size is too small",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-15-a160203b0ebc>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmpool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_reshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    545\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    546\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 547\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    548\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    549\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/pooling.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m     74\u001b[0m         return F.max_pool1d(input, self.kernel_size, self.stride,\n\u001b[1;32m     75\u001b[0m                             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpadding\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdilation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mceil_mode\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 76\u001b[0;31m                             self.return_indices)\n\u001b[0m\u001b[1;32m     77\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     78\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/_jit_internal.py\u001b[0m in \u001b[0;36mfn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    132\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mif_true\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    133\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 134\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mif_false\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    135\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    136\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mif_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__doc__\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mif_false\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__doc__\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36m_max_pool1d\u001b[0;34m(input, kernel_size, stride, padding, dilation, ceil_mode, return_indices)\u001b[0m\n\u001b[1;32m    454\u001b[0m         \u001b[0mstride\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mannotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    455\u001b[0m     return torch.max_pool1d(\n\u001b[0;32m--> 456\u001b[0;31m         input, kernel_size, stride, padding, dilation, ceil_mode)\n\u001b[0m\u001b[1;32m    457\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    458\u001b[0m max_pool1d = boolean_dispatch(\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Given input size: (1x1x1). Calculated output size: (1x1x0). Output size is too small"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "source": [
    "input = torch.randn(10, 1, 7)\n",
    "input"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[[ 0.8161,  0.6422, -0.5168, -0.5121, -1.4209,  0.2685,  1.6642]],\n",
       "\n",
       "        [[ 0.3516,  0.0246,  0.3889, -0.5340,  2.0792, -0.3316,  0.9179]],\n",
       "\n",
       "        [[ 1.1610, -0.9727,  0.3785,  0.4059,  0.1005,  0.3808,  1.4097]],\n",
       "\n",
       "        [[-1.0969,  0.3066, -0.1505, -2.0358,  0.3460, -1.0105,  1.4101]],\n",
       "\n",
       "        [[-0.5502,  0.9587,  1.4179,  0.4418,  1.1144, -0.9469, -1.0101]],\n",
       "\n",
       "        [[ 0.1248,  0.4253,  0.6555, -0.0482,  0.2624,  0.6703,  1.1224]],\n",
       "\n",
       "        [[-0.5518, -1.9113,  0.1987,  0.0145, -1.2341, -1.9231, -1.6173]],\n",
       "\n",
       "        [[-0.1629,  1.7754,  1.3270, -0.1574, -1.7480, -1.5613, -0.0116]],\n",
       "\n",
       "        [[-0.7946, -0.1356,  0.8293,  1.0104,  1.0508,  1.0179, -0.4162]],\n",
       "\n",
       "        [[-1.0098,  2.1383, -0.0133, -1.1540,  2.1836,  0.1288, -0.0404]]])"
      ]
     },
     "metadata": {},
     "execution_count": 91
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "source": [
    "mpool(input)"
   ],
   "outputs": [
    {
     "output_type": "error",
     "ename": "RuntimeError",
     "evalue": "Expected 3-dimensional tensor, but got 2-dimensional tensor for argument #1 'self' (while checking arguments for max_pool1d)",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-115-bc2ebcc1c9ba>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmpool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    545\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    546\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 547\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    548\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    549\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/pooling.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m     74\u001b[0m         return F.max_pool1d(input, self.kernel_size, self.stride,\n\u001b[1;32m     75\u001b[0m                             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpadding\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdilation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mceil_mode\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 76\u001b[0;31m                             self.return_indices)\n\u001b[0m\u001b[1;32m     77\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     78\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/_jit_internal.py\u001b[0m in \u001b[0;36mfn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    132\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mif_true\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    133\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 134\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mif_false\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    135\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    136\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mif_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__doc__\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mif_false\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__doc__\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36m_max_pool1d\u001b[0;34m(input, kernel_size, stride, padding, dilation, ceil_mode, return_indices)\u001b[0m\n\u001b[1;32m    454\u001b[0m         \u001b[0mstride\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mannotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    455\u001b[0m     return torch.max_pool1d(\n\u001b[0;32m--> 456\u001b[0;31m         input, kernel_size, stride, padding, dilation, ceil_mode)\n\u001b[0m\u001b[1;32m    457\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    458\u001b[0m max_pool1d = boolean_dispatch(\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Expected 3-dimensional tensor, but got 2-dimensional tensor for argument #1 'self' (while checking arguments for max_pool1d)"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "source": [
    "target"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([8, 2, 3])"
      ]
     },
     "metadata": {},
     "execution_count": 116
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "source": [
    "import numpy as np\n",
    "target = torch.from_numpy(np.array([0, 1, 2]))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "source": [
    "input = torch.randn(3, 3, requires_grad=True)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "source": [
    "loss = nn.CrossEntropyLoss()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "source": [
    "output = loss(input, target.long())"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "source": [
    "output.backward()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "source": [
    "output"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor(0.4898, grad_fn=<NllLossBackward>)"
      ]
     },
     "metadata": {},
     "execution_count": 45
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "source": [
    "from sklearn.datasets import load_iris"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "source": [
    "iris = load_iris()\n",
    "X = iris.data\n",
    "Y = iris.target"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "x, x_val, y, y_val = train_test_split(X, Y, test_size=0.33, random_state=42)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "source": [
    "y"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "array([1, 2, 1, 0, 2, 1, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 1, 2, 0, 1, 2, 0,\n",
       "       2, 2, 1, 1, 2, 1, 0, 1, 2, 0, 0, 1, 1, 0, 2, 0, 0, 1, 1, 2, 1, 2,\n",
       "       2, 1, 0, 0, 2, 2, 0, 0, 0, 1, 2, 0, 2, 2, 0, 1, 1, 2, 1, 2, 0, 2,\n",
       "       1, 2, 1, 1, 1, 0, 1, 1, 0, 1, 2, 2, 0, 1, 2, 2, 0, 2, 0, 1, 2, 2,\n",
       "       1, 2, 1, 1, 2, 2, 0, 1, 2, 0, 1, 2])"
      ]
     },
     "metadata": {},
     "execution_count": 122
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "source": [
    "# x and x_val are already 2-D (n_samples, n_features), so reshape(-1, shape[1]) was a\n",
    "# no-op; only the cast to float32 (the default dtype of torch parameters) is needed.\n",
    "x_train = x.astype('float32')\n",
    "y_train = y\n",
    "\n",
    "# y_val is used as returned by train_test_split; no reassignment is required.\n",
    "x_val = x_val.astype('float32')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "source": [
    "y_train"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "array([1, 2, 1, 0, 2, 1, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 1, 2, 0, 1, 2, 0,\n",
       "       2, 2, 1, 1, 2, 1, 0, 1, 2, 0, 0, 1, 1, 0, 2, 0, 0, 1, 1, 2, 1, 2,\n",
       "       2, 1, 0, 0, 2, 2, 0, 0, 0, 1, 2, 0, 2, 2, 0, 1, 1, 2, 1, 2, 0, 2,\n",
       "       1, 2, 1, 1, 1, 0, 1, 1, 0, 1, 2, 2, 0, 1, 2, 2, 0, 2, 0, 1, 2, 2,\n",
       "       1, 2, 1, 1, 2, 2, 0, 1, 2, 0, 1, 2])"
      ]
     },
     "metadata": {},
     "execution_count": 124
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "source": [
    "# y is a NumPy array, which has no .unsqueeze; convert it to a tensor first.\n",
    "# torch's unsqueeze also requires the position of the new axis.\n",
    "# NOTE(review): dim=1 yields a column of shape (n, 1) - confirm that is the axis wanted.\n",
    "torch.from_numpy(y).unsqueeze(1)"
   ],
   "outputs": [
    {
     "output_type": "error",
     "ename": "AttributeError",
     "evalue": "'numpy.ndarray' object has no attribute 'unsqueeze'",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-126-888f9634ba40>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m: 'numpy.ndarray' object has no attribute 'unsqueeze'"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "source": [
    "torch.from_numpy(y).squeeze()"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([1, 2, 1, 0, 2, 1, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 1, 2, 0, 1, 2, 0, 2, 2,\n",
       "        1, 1, 2, 1, 0, 1, 2, 0, 0, 1, 1, 0, 2, 0, 0, 1, 1, 2, 1, 2, 2, 1, 0, 0,\n",
       "        2, 2, 0, 0, 0, 1, 2, 0, 2, 2, 0, 1, 1, 2, 1, 2, 0, 2, 1, 2, 1, 1, 1, 0,\n",
       "        1, 1, 0, 1, 2, 2, 0, 1, 2, 2, 0, 2, 0, 1, 2, 2, 1, 2, 1, 1, 2, 2, 0, 1,\n",
       "        2, 0, 1, 2])"
      ]
     },
     "metadata": {},
     "execution_count": 129
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "source": [
    "# torch.cat takes a sequence of tensors, not a bare tensor (see the TypeError recorded\n",
    "# below), and the previous source line was not valid Python at all.\n",
    "# NOTE(review): only one tensor is in the tuple, so this just copies x; add the other\n",
    "# tensors that should be joined along dim 1.\n",
    "torch.cat((torch.from_numpy(x),), 1)"
   ],
   "outputs": [
    {
     "output_type": "error",
     "ename": "TypeError",
     "evalue": "cat(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-148-5f533bd2ebf7>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m: cat(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "source": [
    "torch.from_numpy(x)"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[5.7000, 2.9000, 4.2000, 1.3000],\n",
       "        [7.6000, 3.0000, 6.6000, 2.1000],\n",
       "        [5.6000, 3.0000, 4.5000, 1.5000],\n",
       "        [5.1000, 3.5000, 1.4000, 0.2000],\n",
       "        [7.7000, 2.8000, 6.7000, 2.0000],\n",
       "        [5.8000, 2.7000, 4.1000, 1.0000],\n",
       "        [5.2000, 3.4000, 1.4000, 0.2000],\n",
       "        [5.0000, 3.5000, 1.3000, 0.3000],\n",
       "        [5.1000, 3.8000, 1.9000, 0.4000],\n",
       "        [5.0000, 2.0000, 3.5000, 1.0000],\n",
       "        [6.3000, 2.7000, 4.9000, 1.8000],\n",
       "        [4.8000, 3.4000, 1.9000, 0.2000],\n",
       "        [5.0000, 3.0000, 1.6000, 0.2000],\n",
       "        [5.1000, 3.3000, 1.7000, 0.5000],\n",
       "        [5.6000, 2.7000, 4.2000, 1.3000],\n",
       "        [5.1000, 3.4000, 1.5000, 0.2000],\n",
       "        [5.7000, 3.0000, 4.2000, 1.2000],\n",
       "        [7.7000, 3.8000, 6.7000, 2.2000],\n",
       "        [4.6000, 3.2000, 1.4000, 0.2000],\n",
       "        [6.2000, 2.9000, 4.3000, 1.3000],\n",
       "        [5.7000, 2.5000, 5.0000, 2.0000],\n",
       "        [5.5000, 4.2000, 1.4000, 0.2000],\n",
       "        [6.0000, 3.0000, 4.8000, 1.8000],\n",
       "        [5.8000, 2.7000, 5.1000, 1.9000],\n",
       "        [6.0000, 2.2000, 4.0000, 1.0000],\n",
       "        [5.4000, 3.0000, 4.5000, 1.5000],\n",
       "        [6.2000, 3.4000, 5.4000, 2.3000],\n",
       "        [5.5000, 2.3000, 4.0000, 1.3000],\n",
       "        [5.4000, 3.9000, 1.7000, 0.4000],\n",
       "        [5.0000, 2.3000, 3.3000, 1.0000],\n",
       "        [6.4000, 2.7000, 5.3000, 1.9000],\n",
       "        [5.0000, 3.3000, 1.4000, 0.2000],\n",
       "        [5.0000, 3.2000, 1.2000, 0.2000],\n",
       "        [5.5000, 2.4000, 3.8000, 1.1000],\n",
       "        [6.7000, 3.0000, 5.0000, 1.7000],\n",
       "        [4.9000, 3.1000, 1.5000, 0.2000],\n",
       "        [5.8000, 2.8000, 5.1000, 2.4000],\n",
       "        [5.0000, 3.4000, 1.5000, 0.2000],\n",
       "        [5.0000, 3.5000, 1.6000, 0.6000],\n",
       "        [5.9000, 3.2000, 4.8000, 1.8000],\n",
       "        [5.1000, 2.5000, 3.0000, 1.1000],\n",
       "        [6.9000, 3.2000, 5.7000, 2.3000],\n",
       "        [6.0000, 2.7000, 5.1000, 1.6000],\n",
       "        [6.1000, 2.6000, 5.6000, 1.4000],\n",
       "        [7.7000, 3.0000, 6.1000, 2.3000],\n",
       "        [5.5000, 2.5000, 4.0000, 1.3000],\n",
       "        [4.4000, 2.9000, 1.4000, 0.2000],\n",
       "        [4.3000, 3.0000, 1.1000, 0.1000],\n",
       "        [6.0000, 2.2000, 5.0000, 1.5000],\n",
       "        [7.2000, 3.2000, 6.0000, 1.8000],\n",
       "        [4.6000, 3.1000, 1.5000, 0.2000],\n",
       "        [5.1000, 3.5000, 1.4000, 0.3000],\n",
       "        [4.4000, 3.0000, 1.3000, 0.2000],\n",
       "        [6.3000, 2.5000, 4.9000, 1.5000],\n",
       "        [6.3000, 3.4000, 5.6000, 2.4000],\n",
       "        [4.6000, 3.4000, 1.4000, 0.3000],\n",
       "        [6.8000, 3.0000, 5.5000, 2.1000],\n",
       "        [6.3000, 3.3000, 6.0000, 2.5000],\n",
       "        [4.7000, 3.2000, 1.3000, 0.2000],\n",
       "        [6.1000, 2.9000, 4.7000, 1.4000],\n",
       "        [6.5000, 2.8000, 4.6000, 1.5000],\n",
       "        [6.2000, 2.8000, 4.8000, 1.8000],\n",
       "        [7.0000, 3.2000, 4.7000, 1.4000],\n",
       "        [6.4000, 3.2000, 5.3000, 2.3000],\n",
       "        [5.1000, 3.8000, 1.6000, 0.2000],\n",
       "        [6.9000, 3.1000, 5.4000, 2.1000],\n",
       "        [5.9000, 3.0000, 4.2000, 1.5000],\n",
       "        [6.5000, 3.0000, 5.2000, 2.0000],\n",
       "        [5.7000, 2.6000, 3.5000, 1.0000],\n",
       "        [5.2000, 2.7000, 3.9000, 1.4000],\n",
       "        [6.1000, 3.0000, 4.6000, 1.4000],\n",
       "        [4.5000, 2.3000, 1.3000, 0.3000],\n",
       "        [6.6000, 2.9000, 4.6000, 1.3000],\n",
       "        [5.5000, 2.6000, 4.4000, 1.2000],\n",
       "        [5.3000, 3.7000, 1.5000, 0.2000],\n",
       "        [5.6000, 3.0000, 4.1000, 1.3000],\n",
       "        [7.3000, 2.9000, 6.3000, 1.8000],\n",
       "        [6.7000, 3.3000, 5.7000, 2.1000],\n",
       "        [5.1000, 3.7000, 1.5000, 0.4000],\n",
       "        [4.9000, 2.4000, 3.3000, 1.0000],\n",
       "        [6.7000, 3.3000, 5.7000, 2.5000],\n",
       "        [7.2000, 3.0000, 5.8000, 1.6000],\n",
       "        [4.9000, 3.6000, 1.4000, 0.1000],\n",
       "        [6.7000, 3.1000, 5.6000, 2.4000],\n",
       "        [4.9000, 3.0000, 1.4000, 0.2000],\n",
       "        [6.9000, 3.1000, 4.9000, 1.5000],\n",
       "        [7.4000, 2.8000, 6.1000, 1.9000],\n",
       "        [6.3000, 2.9000, 5.6000, 1.8000],\n",
       "        [5.7000, 2.8000, 4.1000, 1.3000],\n",
       "        [6.5000, 3.0000, 5.5000, 1.8000],\n",
       "        [6.3000, 2.3000, 4.4000, 1.3000],\n",
       "        [6.4000, 2.9000, 4.3000, 1.3000],\n",
       "        [5.6000, 2.8000, 4.9000, 2.0000],\n",
       "        [5.9000, 3.0000, 5.1000, 1.8000],\n",
       "        [5.4000, 3.4000, 1.7000, 0.2000],\n",
       "        [6.1000, 2.8000, 4.0000, 1.3000],\n",
       "        [4.9000, 2.5000, 4.5000, 1.7000],\n",
       "        [5.8000, 4.0000, 1.2000, 0.2000],\n",
       "        [5.8000, 2.6000, 4.0000, 1.2000],\n",
       "        [7.1000, 3.0000, 5.9000, 2.1000]], dtype=torch.float64)"
      ]
     },
     "metadata": {},
     "execution_count": 149
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "import numpy as np"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "a = np.array([1, 2, 3, 3])\n",
    "# Boolean-mask indexing: keep the entries of `a` that are not 3.\n",
    "a[a != 3]"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "array([1, 2])"
      ]
     },
     "metadata": {},
     "execution_count": 5
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "# Indices (one array per dimension) of the entries of `a` that are not 3.\n",
    "np.nonzero(a != 3)"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(array([0, 1], dtype=int64),)"
      ]
     },
     "metadata": {},
     "execution_count": 6
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "source": [
    "# 4x4 integer matrix used below to try out row selection with a mask.\n",
    "b = np.array([\n",
    "    [1, 2, 3, 4],\n",
    "    [2, 3, 4, 5],\n",
    "    [3, 4, 5, 6],\n",
    "    [4, 5, 6, 7],\n",
    "])"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "source": [
    "# np.logical_or is a binary ufunc: a third positional argument is taken as\n",
    "# the `out` buffer, so `a == 1` was never OR-ed in. Reduce over all three\n",
    "# comparisons to get the element-wise OR of every condition.\n",
    "mask = np.logical_or.reduce([a == 3, a == 2, a == 1])"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "source": [
    "b[mask]"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "source": [
    "np.unique(a)"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "array([1, 2, 3])"
      ]
     },
     "metadata": {},
     "execution_count": 15
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [],
   "outputs": [],
   "metadata": {}
  }
 ]
}